package com.stars.easyms.base.trace;

import com.alibaba.ttl.TransmittableThreadLocal;
import com.stars.easyms.base.constant.EasyMsCommonConstants;
import com.stars.easyms.base.constant.HttpHeaderConstants;
import com.stars.easyms.base.bean.EasyMsRequestEntity;
import com.stars.easyms.base.util.ClassUtil;
import com.stars.easyms.base.util.PlumeLogUtil;
import lombok.AllArgsConstructor;
import org.slf4j.MDC;
import org.springframework.lang.Nullable;

import java.util.concurrent.atomic.AtomicInteger;

import static com.stars.easyms.base.constant.EasyMsCommonConstants.ASYNC_ID_CONNECTION_SYMBOL;

/**
 * <p>className: EasyMsTraceSynchronizationManager</p>
 * <p>description: EasyMs的traceId、用户同步管理器</p>
 *
 * @author guoguifang
 * @date 2020-08-20 20:38
 * @since 1.6.1
 */
public final class EasyMsTraceSynchronizationManager {

    private static boolean plumeLogExist = ClassUtil.isExist(EasyMsCommonConstants.PLUME_LOG_TRACE_ID_CLASS_NAME);

    private static final ThreadLocal<TraceBean> TRACE_BEAN_THREAD_LOCAL = new TransmittableThreadLocal<>();

    private static final ThreadLocal<AsyncBean> ASYNC_TRACE_BEAN_THREAD_LOCAL = new TransmittableThreadLocal<>();

    private static final ThreadLocal<EasyMsRequestEntity> REQUEST_ENTITY_THREAD_LOCAL = new TransmittableThreadLocal<>();

    private static final ThreadLocal<UserBean> EASY_MS_REST_USER_INFO_THREAD_LOCAL = new TransmittableThreadLocal<>();

    /**
     * 把traceId和requestId放入日志线程本地变量中
     */
    public static void setTraceInfo(String traceId, String requestId) {
        String mdcTraceInfo = "[traceId: " + traceId + "]-[requestId: " + requestId + "]";
        TRACE_BEAN_THREAD_LOCAL.set(new TraceBean(traceId, requestId, mdcTraceInfo));
        MDC.put(HttpHeaderConstants.TRACE_KEY, mdcTraceInfo);
        if (plumeLogExist) {
            PlumeLogUtil.setTraceId(traceId);
        }
    }

    public static void setTraceId(String traceId) {
        String mdcTraceInfo = "[traceId: " + traceId + "]";
        TRACE_BEAN_THREAD_LOCAL.set(new TraceBean(traceId, null, mdcTraceInfo));
        MDC.put(HttpHeaderConstants.TRACE_KEY, mdcTraceInfo);
        if (plumeLogExist) {
            PlumeLogUtil.setTraceId(traceId);
        }
    }

    @Nullable
    public static String getTraceId() {
        TraceBean traceBean = TRACE_BEAN_THREAD_LOCAL.get();
        if (traceBean != null) {
            return traceBean.traceId;
        }
        if (plumeLogExist) {
            return PlumeLogUtil.getTraceId();
        }
        return null;
    }

    @Nullable
    public static String getRequestId() {
        TraceBean traceBean = TRACE_BEAN_THREAD_LOCAL.get();
        if (traceBean != null) {
            return traceBean.requestId;
        }
        return null;
    }

    public static void setAsyncId(String currAsyncId) {
        AsyncBean currAsyncBean = new AsyncBean();
        currAsyncBean.asyncId = currAsyncId;
        String mdcTraceInfo = "[asyncId: " + currAsyncId + "]";
        TraceBean traceBean = TRACE_BEAN_THREAD_LOCAL.get();
        if (traceBean != null) {
            mdcTraceInfo = traceBean.mdcTraceInfo + "-" + mdcTraceInfo;
        }
        ASYNC_TRACE_BEAN_THREAD_LOCAL.set(currAsyncBean);
        MDC.put(HttpHeaderConstants.TRACE_KEY, mdcTraceInfo);
    }

    public static void entryAsyncThread(String traceId) {
        AsyncBean parentAsyncBean = ASYNC_TRACE_BEAN_THREAD_LOCAL.get();
        String currAsyncId = parentAsyncBean == null ? traceId + ASYNC_ID_CONNECTION_SYMBOL + 1 :
                parentAsyncBean.asyncId + ASYNC_ID_CONNECTION_SYMBOL + parentAsyncBean.subBeanSeqAlloc.incrementAndGet();
        setAsyncId(currAsyncId);
    }

    @Nullable
    public static String getAsyncId() {
        AsyncBean asyncBean = ASYNC_TRACE_BEAN_THREAD_LOCAL.get();
        if (asyncBean != null) {
            return asyncBean.asyncId;
        }
        return null;
    }

    public static void setRequestEntity(EasyMsRequestEntity easyMsRequestEntity) {
        REQUEST_ENTITY_THREAD_LOCAL.set(easyMsRequestEntity);
    }

    @Nullable
    public static EasyMsRequestEntity getRequestEntity() {
        return REQUEST_ENTITY_THREAD_LOCAL.get();
    }

    public static void setUserInfo(String userInfo, String decodedUserInfo) {
        EASY_MS_REST_USER_INFO_THREAD_LOCAL.set(new UserBean(userInfo, decodedUserInfo));
    }

    @Nullable
    public static String getCurrentUserInfo() {
        UserBean userBean = EASY_MS_REST_USER_INFO_THREAD_LOCAL.get();
        if (userBean != null) {
            return userBean.userInfo;
        }
        return null;
    }

    @Nullable
    public static String getDecodedCurrentUserInfo() {
        UserBean userBean = EASY_MS_REST_USER_INFO_THREAD_LOCAL.get();
        if (userBean != null) {
            return userBean.decodedUserInfo;
        }
        return null;
    }

    public static void clearAsyncId() {
        ASYNC_TRACE_BEAN_THREAD_LOCAL.remove();
    }

    public static void clearTraceInfo() {
        // 清除日志线程本地变量中的traceId和requestId
        MDC.remove(HttpHeaderConstants.TRACE_KEY);
        TRACE_BEAN_THREAD_LOCAL.remove();
        if (plumeLogExist) {
            PlumeLogUtil.remove();
        }
        clearAsyncId();
    }

    public static void clear() {
        EASY_MS_REST_USER_INFO_THREAD_LOCAL.remove();
        REQUEST_ENTITY_THREAD_LOCAL.remove();
        clearTraceInfo();
    }

    public static void capture(EasyMsTraceBean easyMsTraceBean) {
        TraceBean traceBean = TRACE_BEAN_THREAD_LOCAL.get();
        if (traceBean == null) {
            if (plumeLogExist) {
                easyMsTraceBean.setTraceId(PlumeLogUtil.getTraceId());
            }
            return;
        }
        easyMsTraceBean.setTraceId(traceBean.traceId);
        easyMsTraceBean.setRequestId(traceBean.requestId);
        easyMsTraceBean.setAsyncId(getAsyncId());
    }

    public static void restore(EasyMsTraceBean easyMsTraceBean) {
        String traceId;
        if (easyMsTraceBean == null || (traceId = easyMsTraceBean.getTraceId()) == null) {
            return;
        }
        String requestId = easyMsTraceBean.getRequestId();
        if (requestId != null) {
            setTraceInfo(traceId, requestId);
        } else {
            setTraceId(traceId);
        }
        String asyncId = easyMsTraceBean.getAsyncId();
        if (asyncId != null) {
            setAsyncId(asyncId);
        }
    }

    @AllArgsConstructor
    private static class TraceBean {

        private String traceId;

        private String requestId;

        private String mdcTraceInfo;
    }

    private static class AsyncBean {

        private AtomicInteger subBeanSeqAlloc = new AtomicInteger(0);

        private String asyncId;
    }

    @AllArgsConstructor
    private static class UserBean {

        private String userInfo;

        private String decodedUserInfo;
    }

    private EasyMsTraceSynchronizationManager() {
    }

}